# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json

from opentelemetry import trace

from llama_stack_api import OpenAIMessageParam, RunShieldResponse

from .constants import (
    RUN_SHIELD_OPERATION_NAME,
    SAFETY_REQUEST_MESSAGES_ATTRIBUTE,
    SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE,
    SAFETY_RESPONSE_METADATA_ATTRIBUTE,
    SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
    SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
)


def safety_span_name(shield_id: str) -> str:
    return f"{RUN_SHIELD_OPERATION_NAME} {shield_id}"


# TODO: Consider using Wrapt to automatically instrument code
# This is the industry standard way to package automatically instrumentation in python.
def safety_request_span_attributes(
    shield_id: str, messages: list[OpenAIMessageParam], response: RunShieldResponse
) -> None:
    span = trace.get_current_span()
    span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
    messages_json = json.dumps([msg.model_dump() for msg in messages])
    span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)

    if response.violation:
        if response.violation.metadata:
            metadata_json = json.dumps(response.violation.metadata)
            span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
        if response.violation.user_message:
            span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message)
        span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)
